# imports
from sklearn.preprocessing import StandardScaler
import numpy as np
import pandas as pd
import copy
from sklearn.linear_model import LassoCV
from sklearn.model_selection import GroupKFold
import pickle
import scipy as sp
from scipy.stats import expon
import matplotlib.pyplot as plt
from sklearn.mixture import GaussianMixture
from sklearn.linear_model import LinearRegression
import joblib

# covariance matrix
def new_cov_matrix(cov):
    """Build a simulated covariance matrix from the spectrum of ``cov``.

    The eigenvalues of ``cov`` are replaced by piecewise-linear fits over four
    index segments and the eigenvectors by sparse random unit vectors. The
    last four eigenpairs returned by ``eigh`` (the largest ones, since ``eigh``
    reports eigenvalues in ascending order) are then restored from ``cov``.

    Args:
        cov: square matrix. The segment boundaries 35, 77 and 86 are
            hard-coded, so the matrix is expected to be larger than 86 x 86.

    Returns:
        ``V @ diag(values) @ V.T`` built from the simulated vectors and values.
    """
    dim = cov.shape[0]
    eigvals, eigvecs = sp.linalg.eigh(cov)

    # Fit one straight line per eigenvalue segment; the final segment ends at `dim`.
    boundaries = [0, 35, 77, 86, dim]
    fitted_segments = [
        linear_approximation(lo, hi, eigvals)
        for lo, hi in zip(boundaries[:-1], boundaries[1:])
    ]
    # The empty float seed keeps the result dtype identical to repeated np.append.
    sim_vals = np.concatenate([np.array([])] + fitted_segments)

    # Each simulated eigenvector is a unit vector supported on (at most) six
    # randomly chosen coordinates; indices are drawn with replacement.
    sim_vecs = np.zeros_like(eigvecs)
    for col in range(dim):
        direction = np.zeros(dim)
        support = np.random.choice(dim, 6)
        direction[support] += np.random.normal(0.01, 0.06, size=6)
        sim_vecs[:, col] = direction / np.linalg.norm(direction)

    # Restore the last four eigenpairs of the input matrix.
    sim_vals[-4:] = eigvals[-4:]
    sim_vecs[:, -4:] = eigvecs[:, -4:]

    # Wherever the linear fit went negative, fall back to the original eigenvalue.
    went_negative = sim_vals < 0
    sim_vals[went_negative] = eigvals[went_negative]

    scaled = sim_vecs.dot(np.diag(sim_vals))
    return scaled.dot(sim_vecs.T)

# linear approximation of a segment of the eigenvalues


def linear_approximation(start, end, e_val):
    """Fit a straight line to ``e_val[start:end]`` and return the fitted values.

    The regressor is the position ``start, ..., end - 1``; the returned array
    holds the line evaluated at those same positions.
    """
    positions = np.arange(start, end).reshape(-1, 1)
    line = LinearRegression().fit(positions, e_val[start:end])
    return line.predict(positions)


# coefs
def generate_coefs(index, columns):
    """Simulate a sparse coefficient table with outcomes as rows and features as columns.

    Only the first seven outcomes in ``index`` are filled, and ``index[:4]``
    is used as the pool of proxy names. For each filled outcome:

    * two demo columns (drawn with replacement, so possibly the same one)
      receive a single uniform random coefficient;
    * the columns whose name contains the outcome name receive six ascending
      weights;
    * the columns of one randomly chosen proxy (a *different* proxy when the
      outcome is itself a proxy) receive six smaller ascending weights.

    Args:
        index: outcome names; names containing "proxy" or "investment" are
            filled, any other name leaves its row at zero.
        columns: feature names. Each outcome is expected to own exactly six
            columns containing its name, since six values are assigned.

    Returns:
        ``pd.DataFrame`` of float coefficients, zero where nothing was drawn.
    """
    # Start from a float frame: filling an integer frame with floats relies on
    # implicit upcasting, which newer pandas deprecates, and leaves untouched
    # columns as integers.
    simulated_coefs_df = pd.DataFrame(0.0, index=index, columns=columns)

    # Column positions of each feature group.
    ind_demo = [pos for pos, col in enumerate(columns) if "demo" in col]
    ind_proxy = [pos for pos, col in enumerate(columns) if "proxy" in col]
    ind_investment = [
        pos for pos, col in enumerate(columns) if "investment" in col
    ]
    proxy_names = index[:4]

    for i in range(7):
        outcome_name = simulated_coefs_df.index[i]
        if "proxy" in outcome_name:
            own_group = ind_proxy
            candidates = [
                proxy for proxy in proxy_names if proxy != outcome_name]
            demo_low, cross_scale = 0.004, 5e-2
        elif "investment" in outcome_name:
            own_group = ind_investment
            candidates = proxy_names
            demo_low, cross_scale = 0.001, 1e-1
        else:
            continue

        ind_own = [ind for ind in own_group if outcome_name in columns[ind]]
        random_proxy_name = np.random.choice(candidates)
        ind_cross = [
            ind for ind in ind_proxy if random_proxy_name in columns[ind]]

        # The value is drawn before the demo columns are picked so that the
        # sequence of random draws is the same as it has always been.
        demo_value = np.random.uniform(demo_low, 0.05)
        demo_cols = np.random.choice(ind_demo, 2)
        simulated_coefs_df.iloc[i, demo_cols] = demo_value
        simulated_coefs_df.iloc[i, ind_own] = _lagged_weights(5e-1)
        simulated_coefs_df.iloc[i, ind_cross] = _lagged_weights(cross_scale)
    return simulated_coefs_df


def _lagged_weights(scale):
    """Return six distinct ascending values sampled from ``scale * exp(-k)``, k = 0..9."""
    grid = expon.pdf(np.arange(10)) * scale
    return sorted(np.random.choice(grid, 6, replace=False))

# residuals


def simulate_residuals(ind):
    """Draw a fresh residual vector for outcome number ``ind`` from saved fits.

    Reads, from the current working directory:

    * ``n_{ind}.jbl`` - a triple ``(n, n_pos, n_neg)``: total sample size and
      the number of positive / negative outliers;
    * ``gm_{ind}.jbl`` - an object exposing ``sample`` (presumably a fitted
      Gaussian mixture), used for the ``n - n_pos - n_neg`` regular draws;
    * ``lognorm_pos_{ind}.jbl`` / ``lognorm_neg_{ind}.jbl`` - ``(s, loc, scale)``
      log-normal parameters, loaded only when the matching count is positive.

    Returns:
        1-D array: regular draws, then positive outliers, then negated
        negative outliers, in that order.
    """
    n, n_pos, n_neg = joblib.load("n_{}.jbl".format(ind))
    mixture = joblib.load("gm_{}.jbl".format(ind))
    core = mixture.sample(n - n_pos - n_neg)[0].flatten()

    def draw_outliers(path_template, count):
        # Log-normal draws from the saved parameters; empty when none are requested.
        if count > 0:
            s, loc, scale = joblib.load(path_template.format(ind))
            return sp.stats.lognorm(s, loc=loc, scale=scale).rvs(size=count)
        return np.array([])

    upper_tail = draw_outliers("lognorm_pos_{}.jbl", n_pos)
    # Negative outliers are stored as magnitudes, so flip the sign.
    lower_tail = -draw_outliers("lognorm_neg_{}.jbl", n_neg)
    return np.concatenate((core, upper_tail, lower_tail))


def simulate_residuals_all(res_df):
    """Return a copy of ``res_df`` with every column replaced by simulated residuals.

    Column ``i`` (by position) is overwritten with ``simulate_residuals(i)``,
    after which each column is re-centred to mean zero. The input frame is not
    modified.
    """
    simulated = res_df.copy(deep=True)
    n_columns = res_df.shape[1]
    for position in range(n_columns):
        simulated.iloc[:, position] = simulate_residuals(position)
    # Demean each column again after the redraw.
    column_means = simulated.mean(axis=0)
    return simulated - column_means


# generate data
def get_prediction(df, coef_matrix, residuals, thetas, n, intervention, columns, index, counterfactual):
    """Fill the outcome columns of ``df`` for one period and return ``df``.

    The first four names in ``index`` are treated as proxies and the remaining
    ones as investments. ``df`` is modified in place.

    Args:
        df: frame holding the feature columns listed in ``columns``.
        coef_matrix: coefficients, one row per outcome in ``index`` order.
        residuals: per-row residuals, one column per outcome.
        thetas: investment effects; ``investments @ thetas.T`` is added to the proxies.
        n: unused.
        intervention: added to the investment predictions.
        columns: feature column names read from ``df``.
        index: outcome column names written to ``df``.
        counterfactual: when true, investments are set to zero instead of
            being predicted.

    Returns:
        The same ``df`` object with the outcome columns written.
    """
    n_proxy = 4
    proxy_names, investment_names = index[:n_proxy], index[n_proxy:]

    linear_part = np.matmul(df[columns].values, coef_matrix.T)
    investment_part = linear_part[:, n_proxy:]

    # Investments: linear prediction + noise + intervention, or all zeros.
    if not counterfactual:
        investments = investment_part + residuals[:, n_proxy:] + intervention
    else:
        investments = np.zeros(investment_part.shape)
    df[investment_names] = pd.DataFrame(investments, index=df.index)

    # Proxies additionally respond to the current-period investments.
    investment_effect = np.matmul(investments, thetas.T)
    proxies = linear_part[:, :n_proxy] + residuals[:, :n_proxy] + investment_effect
    df[proxy_names] = pd.DataFrame(proxies, index=df.index)
    return df


def generate_dgp(
    cov_matrix,
    n_tpid,
    t_period,
    coef_matrix,
    residual_matrix,
    thetas,
    intervention,
    columns,
    index,
    counterfactual
):
    """Simulate a panel of ``n_tpid`` units over ``t_period`` periods.

    Period 1 features are drawn from a zero-mean multivariate normal with
    covariance ``cov_matrix``, and its outcomes always include
    ``intervention`` (the counterfactual switch is off). For every later
    period the lagged outcome columns ``{name}_-6 .. {name}_-1`` are shifted
    by one step, the previous outcome becomes ``{name}_-1``, and outcomes are
    predicted again with a zero intervention and the caller's
    ``counterfactual`` flag.

    Args:
        cov_matrix: covariance of the period-1 features, one row per entry of ``columns``.
        n_tpid: number of units.
        t_period: number of periods; period 1 is always generated.
        coef_matrix: passed through to ``get_prediction``.
        residual_matrix: indexable by period (``residual_matrix[t - 1]``).
        thetas: passed through to ``get_prediction``.
        intervention: added to the period-1 investments only.
        columns: feature names; expected to contain the ``{name}_{lag}``
            columns for every name in ``index``.
        index: outcome names.
        counterfactual: forwarded to ``get_prediction`` for periods >= 2.

    Returns:
        One frame with all periods stacked, sorted by ``id`` then ``datetime``.
    """
    # First period: sampled features plus id / datetime bookkeeping columns.
    m = cov_matrix.shape[0]
    x = np.random.multivariate_normal(np.repeat(0, m), cov_matrix, size=n_tpid)
    df = pd.DataFrame(
        np.hstack(
            (np.arange(n_tpid).reshape(-1, 1),
             np.repeat(1, n_tpid).reshape(-1, 1), x)
        ),
        columns=["id", "datetime"] + columns,
    )
    df = get_prediction(df, coef_matrix, residual_matrix[0],
                        thetas, n_tpid, intervention, columns, index, False)

    # Collect the per-period frames and concatenate once at the end; growing a
    # frame with concat inside the loop re-copies all earlier periods each time.
    # Every later period works on a fresh deep copy, so stored frames are never
    # modified after being appended.
    period_frames = [df]

    # Iterate the step-ahead construction.
    for t in range(2, t_period + 1):
        new_df = df.copy(deep=True)
        new_df["datetime"] = np.repeat(t, n_tpid)
        for name in index:
            # Shift every lag one step back, then use last period's outcome as lag -1.
            for i in range(-6, -1):
                new_df[f"{name}_{i}"] = df[f"{name}_{i+1}"]
            new_df[f"{name}_-1"] = df[name]
        df = get_prediction(new_df, coef_matrix, residual_matrix[t - 1],
                            thetas, n_tpid, [0, 0, 0], columns, index, counterfactual)
        period_frames.append(df)

    df_all = pd.concat(period_frames, axis=0)
    return df_all.sort_values(["id", "datetime"])

def pretty_print_summary(summary, treatment_names):
    """Return a DML summary table as a dataframe with one row per treatment.

    ``summary.tables[0].data`` is read as a 2-D table. Its first column is
    dropped and the rest transposed, which is expected to give rows of
    ``(Property, Treatment, Value)`` when there are several treatments and
    ``(Property, Value)`` when there is exactly one.

    Args:
        summary: object exposing ``tables[0].data``.
        treatment_names: labels for the rows of the result. They are matched
            by position to the treatment labels found in the table, which come
            out of the groupby in sorted order.

    Returns:
        ``pd.DataFrame`` indexed by treatment name with one column per property.
    """
    cells = np.array(summary.tables[0].data)[:, 1:].T
    if len(treatment_names) > 1:
        long_df = pd.DataFrame(
            cells, columns=["Property", "Treatment", "Value"])
    else:
        long_df = pd.DataFrame(cells, columns=["Property", "Value"])
        long_df["Treatment"] = treatment_names[0]

    # Select the value columns explicitly after grouping, so the applied
    # function does not depend on whether pandas passes the grouping column
    # to it (that behaviour is deprecated in newer pandas).
    df = (
        long_df.groupby("Treatment")[["Property", "Value"]]
        .apply(lambda d: d.set_index("Property").transpose())
        .droplevel(1)
    )
    return df.rename(index=dict(zip(df.index.values, treatment_names)))